{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "02f1617c-e8ee-4123-8416-5ec1a7974cab",
   "metadata": {},
   "source": [
    "This notebook explores the end-to-end benchmarking pipeline, including:\n",
    "\n",
    "1. Initializing dataset and dataloader\n",
    "2. Initializing model, either from our benchmark model definition or your own use cases\n",
    "3. Running the model given input data\n",
    "4. Defining criterion (e.g., MSE, RMSE)\n",
    "5. Benchmarking against validation (observation) and testing (forecasting model) data\n",
    "\n",
    "NOTE: This notebook does not contain the training pipeline..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6dbd3180-4d25-46b7-90fc-f8db79661301",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d8dc2f0a-5b6d-4382-87a8-fe2aecd1b786",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/burg/home/jn2808/.conda/envs/bench/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import xarray as xr\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "from glob import glob\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from chaosbench import dataset, config, utils, criterion\n",
    "from chaosbench.models import mlp, cnn, ae, fno, vit\n",
    "\n",
    "import logging\n",
    "logging.basicConfig(level=logging.INFO)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71e77b69-391b-4335-816b-5a03c32983fb",
   "metadata": {},
   "source": [
    "## Dataset Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26f60dfa-2099-43f7-ade6-01f71b57428f",
   "metadata": {},
   "source": [
    "First of all, we are initializing our Dataset and Dataloader that are going to be used for training / evaluation processes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "562a1d04-95f7-4544-ac2a-d4f1bd092941",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/burg/home/jn2808/.conda/envs/bench/lib/python3.9/site-packages/gribapi/__init__.py:23: UserWarning: ecCodes 2.31.0 or higher is recommended. You are running version 2.30.0\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Specify train/val years + test benchmark\n",
    "train_years = np.arange(2016, 2022)\n",
    "val_years = np.arange(2022, 2023)\n",
    "\n",
    "# Also land + ocean variables to be included (acronyms are detailed in the project webpage)\n",
    "land_vars = []\n",
    "ocean_vars = []\n",
    "\n",
    "# Initialize Dataset objects\n",
    "N_STEP = 1\n",
    "LEAD_TIME = 1\n",
    "train_dataset = dataset.S2SObsDataset(\n",
    "    years=train_years, \n",
    "    n_step=N_STEP, \n",
    "    lead_time=LEAD_TIME, \n",
    "    land_vars=land_vars, \n",
    "    ocean_vars=ocean_vars\n",
    ")\n",
    "\n",
    "val_dataset = dataset.S2SObsDataset(\n",
    "    years=val_years, \n",
    "    n_step=N_STEP, \n",
    "    lead_time=LEAD_TIME, \n",
    "    land_vars=land_vars, \n",
    "    ocean_vars=ocean_vars\n",
    ")\n",
    "\n",
    "# test_dataset = dataset.S2SEvalDataset(s2s_name='ncep', years=val_years) ## OPTIONAL\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0de7c6a7-533a-45c0-b445-279cb1201446",
   "metadata": {},
   "source": [
    "You have the flexibility to define your own DataLoader here, including the batch_size, etc.."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9b6144b9-ee33-449b-82ed-371ba1bc61a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define your own Dataloader\n",
    "batch_size = 4\n",
    "\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "# test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) ## OPTIONAL\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5977a244-5138-418c-b4be-886fadfb6230",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect a batch\n",
    "_, train_x, train_y = next(iter(train_dataloader))\n",
    "_, val_x, val_y = next(iter(val_dataloader))\n",
    "\n",
    "# _, test_x, test_y = next(iter(test_dataloader)) ## OPTIONAL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1bc1bfee-e65a-4f20-b4ac-28067d3fdf92",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train/val x: torch.Size([4, 60, 121, 240])\n",
      "train/val y: torch.Size([4, 1, 60, 121, 240])\n"
     ]
    }
   ],
   "source": [
    "print(f'train/val x: {train_x.shape}') # Each tensor has the shape of (batch_size, params, lat, lon)\n",
    "print(f'train/val y: {train_y.shape}') # Each tensor has the shape of (batch_size, step_size, params, lat, lon)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1c4657f3-4bb1-4a13-8595-5f8a1d8b5054",
   "metadata": {},
   "outputs": [],
   "source": [
    "## OPTIONAL\n",
    "# print(f'test x: {test_x.shape}') # Each tensor has the shape of (batch_size, params, level, lat, lon)\n",
    "# print(f'test y: {test_y.shape}') # Each tensor has the shape of (batch_size, lead_time=44, params, level, lat, lon)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a4a210a-122b-4429-a7cf-e63e43cc9151",
   "metadata": {},
   "source": [
    "## Modeling"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fee9a8b0-80ab-472e-b239-55788adf4d06",
   "metadata": {},
   "source": [
    "Now that we have our Dataset and Dataloader setup, we can begin the modeling process. Our benchmark model architectures are defined under `chaosbench/models`\n",
    "\n",
    "As a starter, we can define an autoencoder..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b1e23cbe-cf32-407a-95a0-3536298f6391",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specify model specifications\n",
    "\n",
    "model = cnn.UNet(input_size=train_x.shape[1], output_size=train_x.shape[1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "036344fb-badd-49ef-9edc-22444c43a1e1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 60, 121, 240])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run the model to get output\n",
    "preds = model(train_x)\n",
    "preds = preds.reshape(tuple([batch_size]) + tuple(torch.tensor(train_x.shape[1:])))\n",
    "preds.shape\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d86b75cc-9fd5-4121-b4f9-2fc8b6d5c180",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9f5fac83-d23e-47c2-9919-e7804640b3f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We define what error metrics we want to compute (e.g., RMSE)\n",
    "rmse = criterion.RMSE(lat_adjusted=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "99308b7a-424b-4c9a-a76e-a002c68f6ed8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0959651470184326\n"
     ]
    }
   ],
   "source": [
    "# Compute error\n",
    "preds = model(train_x)\n",
    "error = rmse(preds, train_x)\n",
    "\n",
    "print(error.item())\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bench",
   "language": "python",
   "name": "bench"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
